package 多线程.矩阵乘法;

import java.io.File;
import java.util.Date;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class Matix2 {

    //初始化一个随机nxn阶矩阵
    public static int[][] initializationMatrix(int n){
        int[][] result = new int[n][n];
        for(int i = 0;i < n;i++){
            for(int j = 0;j < n;j++){
                result[i][j] = (int)(Math.random()*10); //采用随机函数随机生成1~10之间的数
            }
        }
        return result;
    }

    //蛮力法求解两个nxn和nxn阶矩阵相乘
    public static int[][] BruteForce(int[][] p,int[][] q,int n){
        int[][] result = new int[n][n];
        for(int i=0;i<n;i++){
            for(int j=0;j<n;j++){
                for(int k=0;k<n;k++){
                    result[i][j] += p[i][k]*q[k][j];
                }
            }
        }
        return result;
    }

    //分治法求解两个nxn和nxn阶矩阵相乘
    public static int[][] DivideAndConquer(int[][] p,int[][] q,int n){
        int[][] result = new int[n][n];
        //当n为2时，返回矩阵相乘结果
        if(n == 2){
            result = BruteForce(p,q,n);   //调用蛮力法
            return result;
        }

        //当n大于3时，采用采用分治法，递归求最终结果
        if(n > 2){
            int m = n/2;
            //将矩阵均分为四个子矩阵
            //[1,2]
            //[3,4]
            int[][] p1 = QuarterMatrix(p,n,1);
            int[][] p2 = QuarterMatrix(p,n,2);
            int[][] p3 = QuarterMatrix(p,n,3);
            int[][] p4 = QuarterMatrix(p,n,4);

            int[][] q1 = QuarterMatrix(q,n,1);
            int[][] q2 = QuarterMatrix(q,n,2);
            int[][] q3 = QuarterMatrix(q,n,3);
            int[][] q4 = QuarterMatrix(q,n,4);

            int[][] result1;
            int[][] result2;
            int[][] result3;
            int[][] result4;

            //递归
            result1 = AddMatrix(DivideAndConquer(p1,q1,m),DivideAndConquer(p2,q3,m),m);
            result2 = AddMatrix(DivideAndConquer(p1,q2,m),DivideAndConquer(p2,q4,m),m);
            result3 = AddMatrix(DivideAndConquer(p3,q1,m),DivideAndConquer(p4,q3,m),m);
            result4 = AddMatrix(DivideAndConquer(p3,q2,m),DivideAndConquer(p4,q4,m),m);


            result = TogetherMatrix(result1,result2,result3,result4,m);
        }
        return result;
    }

    //获取矩阵的四分之一，并决定返回哪一个四分之一
    public static int[][] QuarterMatrix(int[][] p,int n,int number){
        int rows = n/2;   //行数减半
        int cols = n/2;   //列数减半
        int[][] result = new int[rows][cols];
        switch(number){
            case 1 :
            {
                // result = new int[rows][cols];
                for(int i=0;i<rows;i++){
                    for(int j=0;j<cols;j++){
                        result[i][j] = p[i][j];
                    }
                }
                break;
            }

            case 2 :
            {
                // result = new int[rows][n-cols];
                for(int i=0;i<rows;i++){
                    for(int j=0;j<n-cols;j++){
                        result[i][j] = p[i][j+cols];
                    }
                }
                break;
            }

            case 3 :
            {
                // result = new int[n-rows][cols];
                for(int i=0;i<n-rows;i++){
                    for(int j=0;j<cols;j++){
                        result[i][j] = p[i+rows][j];
                    }
                }
                break;
            }

            case 4 :
            {
                // result = new int[n-rows][n-cols];
                for(int i=0;i<n-rows;i++){
                    for(int j=0;j<n-cols;j++){
                        result[i][j] = p[i+rows][j+cols];
                    }
                }
                break;
            }

            default:
                break;
        }

        return result;
    }

    //把均分为四分之一的矩阵，聚合成一个矩阵，其中矩阵a,b,c,d分别对应原完整矩阵的四分中1、2、3、4
    public static int[][] TogetherMatrix(int[][] a,int[][] b,int[][] c,int[][] d,int n){
        int[][] result = new int[2*n][2*n];
        for(int i=0;i<2*n;i++){
            for(int j=0;j<2*n;j++){
                if(i<n){//右上
                    if(j<n){//左上
                        result[i][j] = a[i][j];
                    }
                    else
                        result[i][j] = b[i][j-n];
                }
                else{
                    if(j<n){//左下
                        result[i][j] = c[i-n][j];
                    }
                    else{//右下
                        result[i][j] = d[i-n][j-n];
                    }
                }
            }
        }

        return result;
    }


    //求两个矩阵相加结果
    public static int[][] AddMatrix(int[][] p,int[][] q,int n){
        int[][] result = new int[n][n];
        for(int i=0;i<n;i++){
            for(int j=0;j<n;j++){
                result[i][j] = p[i][j]+q[i][j];
            }
        }
        return result;
    }

    //控制台输出矩阵
    public static void PrintfMatrix(int[][] matrix,int n){
        for(int i=0;i<n;i++){
            System.out.println();
            for(int j=0;j<n;j++){
                System.out.print("\t");
                System.out.print(matrix[i][j]);
            }
        }
    }
    public static int[][] multiply(int[][] p, int[][] q, int n) {
        int[][] result = new int[n][n];
        int numThreads = Runtime.getRuntime().availableProcessors();
        ExecutorService executor = Executors.newFixedThreadPool(numThreads);

        for (int i = 0; i < n; i++) {
            final int row = i;
            executor.execute(() -> {
                for (int j = 0; j < n; j++) {
                    for (int k = 0; k < n; k++) {
                        result[row][j] += p[row][k] * q[k][j];
                    }
                }
            });
        }

        executor.shutdown();
        while (!executor.isTerminated()) {
            // 等待所有任务执行完成
        }

        return result;
    }
    public static int[][] duoxiancheng(int[][] p, int[][] q, int n) {
        int result[][] = new int[n][n];
        int numThreads = Runtime.getRuntime().availableProcessors(); // 获取CPU核心数
        Thread[] threads = new Thread[numThreads];
        int m = n / numThreads;

        for (int i = 0; i < numThreads; i++) {
            int finalI = i;
            threads[i] = new Thread(() -> {
                int startRow = finalI * m;
                int endRow = (finalI + 1) * m;

                for (int j = startRow; j < endRow; j++) {
                    for (int k = 0; k < n; k++) {
                        result[j][k] = 0;
                        for (int r = 0; r < n; r++) {
                            result[j][k] += p[j][r] * q[r][k];
                        }
                    }
                }
            });
            threads[i].start();
        }
        boolean flag =true;
        // 等待所有线程完成计算
        while(flag){

            for(Thread thread:threads){
                if(thread.isAlive()){
                    break;
                }
                flag=false;
            }
        }

        return result;
    }

    public static void main(String args[]){
        long start1,end1,start2,end2,start3,end3;
        int n =8;
        int[][] p = initializationMatrix(n);//初始化随机8*8的矩阵
        int[][] q = initializationMatrix(n);
//        System.out.print("矩阵p初始化值为：");
//        PrintfMatrix(p,n);//打印矩阵
//        System.out.println();
//        System.out.print("矩阵q初始化值为：");
//        PrintfMatrix(q,n);
        System.out.println();

        start1 = System.currentTimeMillis();
        int[][] bf_result = BruteForce(p,q,n);//蛮力法计算矩阵相乘
        end1 = System.currentTimeMillis();
        System.out.println();
        System.out.print("蛮力法计算矩阵p*q结果为：");
        PrintfMatrix(bf_result,n);
        System.out.println();


        start2 = System.currentTimeMillis();
        int[][] dac_result = DivideAndConquer(p,q,n);//分治法计算矩阵相乘
        end2 = System.currentTimeMillis();
        System.out.println();
        System.out.print("分治法计算矩阵p*q结果为：");
        PrintfMatrix(dac_result,n);
        System.out.println();

        start3 = System.currentTimeMillis();
        int[][] result = multiply(p,q,n);//分治法计算矩阵相乘
//        int[][] result = 多线程(p,q,n);//分治法计算矩阵相乘
        end3 = System.currentTimeMillis();
        System.out.println();
        System.out.print("多线程计算矩阵p*q结果为：");
        PrintfMatrix(result,n);
        System.out.println();


        System.out.println("蛮力法开始时间：" + start1);
        System.out.println("蛮力法结束时间：" + end1);
        System.out.println("蛮力法运行时间：" + (end1 - start1) + "(ms)\n");

        System.out.println("分治法开始时间：" + start2);
        System.out.println("分治法结束时间：" + end2);
        System.out.println("分治法运行时间：" + (end2 - start2) + "(ms)\n");

        System.out.println("多线程开始时间：" + start3);
        System.out.println("多线程结束时间：" + end3);
        System.out.println("多线程运行时间：" + (end3 - start3) + "(ms)\n");
    }
}
